{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MSCGraph 简介"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/media/pc/data/lxw/ai/tvm-book/doc/read/msc\n"
     ]
    }
   ],
   "source": [
    "%cd ..\n",
    "from pathlib import Path\n",
    "\n",
    "temp_dir = Path(\".temp\")\n",
    "temp_dir.mkdir(exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MSCGraph 是 MSC 的核心，它对编译器的作用类似于 IR（中间表示）。MSCGraph 是 `Relax` 的 DAG（有向无环图）格式。构建 MSCGraph 的目标是使压缩算法的开发和权重管理（这在训练时很重要）更加容易。如果所选的运行时目标不支持所有的 Calls，那么 Relax/Relay 模块将拥有多个 MSCGraphs。MSCGraph 存在于编译过程各个阶段，用于管理模型计算信息。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Overwriting graph/model.py\n"
     ]
    }
   ],
   "source": [
    "%%file graph/model.py\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch import fx\n",
    "import tvm\n",
    "from tvm import relax\n",
    "from tvm.relax.frontend.torch import from_fx\n",
    "\n",
    "class M(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv = torch.nn.Conv2d(3, 6, 1, bias=False)\n",
    "        self.relu = torch.nn.ReLU()\n",
    "\n",
    "    def forward(self, data):\n",
    "        x = self.conv(data)\n",
    "        return self.relu(x)\n",
    "\n",
    "def get_model(input_info):\n",
    "    \"\"\"Trace ``M`` with ``torch.fx`` and import the traced graph through ``from_fx``.\n",
    "\n",
    "    Args:\n",
    "        input_info: Forwarded unchanged to ``from_fx``.\n",
    "\n",
    "    Returns:\n",
    "        A ``(mod, torch_fx_model)`` tuple: the value returned by ``from_fx``\n",
    "        and the ``torch.fx`` traced model it was built from.\n",
    "    \"\"\"\n",
    "    # Convert the frontend model to an IRModule; tracing and import both run with autograd disabled.\n",
    "    with torch.no_grad():\n",
    "        traced = fx.symbolic_trace(M())\n",
    "        # With keep_params_as_input=False the weights become constants in the module\n",
    "        # rather than extra function parameters.\n",
    "        converted = from_fx(traced, input_info, keep_params_as_input=False)\n",
    "    return converted, traced"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 构建 MSC 计算图"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm.contrib.msc.core.frontend import translate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "从 `relax` 构建 msc 计算图："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div class=\"highlight\" style=\"background: \"><pre style=\"line-height: 125%;\"><span></span><span style=\"color: #007979; font-style: italic\"># from tvm.script import ir as I</span>\n",
       "<span style=\"color: #007979; font-style: italic\"># from tvm.script import relax as R</span>\n",
       "\n",
       "<span style=\"color: #AA22FF\">@I</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>ir_module\n",
       "<span style=\"color: #008000; font-weight: bold\">class</span> <span style=\"color: #0000FF; font-weight: bold\">Module</span>:\n",
       "    <span style=\"color: #AA22FF\">@R</span><span style=\"color: #AA22FF; font-weight: bold\">.</span>function\n",
       "    <span style=\"color: #008000; font-weight: bold\">def</span> <span style=\"color: #0000FF\">main</span>(inp_0: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">3</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)) <span style=\"color: #AA22FF; font-weight: bold\">-&gt;</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">6</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>):\n",
       "        <span style=\"color: #008000; font-weight: bold\">with</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>dataflow():\n",
       "            lv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">6</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>conv2d(inp_0, metadata[<span style=\"color: #BA2121\">&quot;relax.expr.Constant&quot;</span>][<span style=\"color: #008000\">0</span>], strides<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], padding<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>, <span style=\"color: #008000\">0</span>], dilation<span style=\"color: #AA22FF; font-weight: bold\">=</span>[<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">1</span>], groups<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #008000\">1</span>, data_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, kernel_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;OIHW&quot;</span>, out_layout<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;NCHW&quot;</span>, out_dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>)\n",
       "            lv1: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">6</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> R<span style=\"color: #AA22FF; font-weight: bold\">.</span>nn<span style=\"color: #AA22FF; font-weight: bold\">.</span>relu(lv)\n",
       "            gv: R<span style=\"color: #AA22FF; font-weight: bold\">.</span>Tensor((<span style=\"color: #008000\">1</span>, <span style=\"color: #008000\">6</span>, <span style=\"color: #008000\">4</span>, <span style=\"color: #008000\">4</span>), dtype<span style=\"color: #AA22FF; font-weight: bold\">=</span><span style=\"color: #BA2121\">&quot;float32&quot;</span>) <span style=\"color: #AA22FF; font-weight: bold\">=</span> lv1\n",
       "            R<span style=\"color: #AA22FF; font-weight: bold\">.</span>output(gv)\n",
       "        <span style=\"color: #008000; font-weight: bold\">return</span> gv\n",
       "\n",
       "<span style=\"color: #007979; font-style: italic\"># Metadata omitted. Use show_meta=True in script() method to show it.</span>\n",
       "</pre></div>\n"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "main <INPUTS: inp_0:0| OUTPUTS: relu:0>\n",
      "N_0 inp_0 <PARENTS: | CHILDERN: conv2d>\n",
      "  OUT: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>\n",
      "  OPTYPE: input\n",
      "\n",
      "N_1 conv2d <PARENTS: inp_0| CHILDERN: relu>\n",
      "  IN: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>\n",
      "  OUT: conv2d:0<1,6,4,4|float32|NCHW>\n",
      "  OPTYPE: nn.conv2d\n",
      "  SCOPE: block\n",
      "  ATTRS: out_dtype=float32 strides=1,1 kernel_layout=OIHW groups=1 padding=0,0,0,0 data_layout=NCHW dilation=1,1 out_layout=NCHW \n",
      "  WEIGHTS: \n",
      "    weight: const<6,3,1,1|float32|OIHW>\n",
      "\n",
      "N_2 relu <PARENTS: conv2d| CHILDERN: >\n",
      "  IN: conv2d:0<1,6,4,4|float32|NCHW>\n",
      "  OUT: relu:0(relu)<1,6,4,4|float32|NCHW>\n",
      "  OPTYPE: nn.relu\n",
      "  SCOPE: block\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from graph.model import get_model\n",
    "input_info = [((1, 3, 4, 4), \"float32\")] # 给定输入 shape 和数据类型\n",
    "mod, _ = get_model(input_info)\n",
    "mod.show()\n",
    "graph, weights = translate.from_relax(mod)\n",
    "print(graph)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tvm.contrib.msc.core.ir.graph.MSCGraph"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(graph)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{topic} MSCGraph 和 Relax 之间的差异\n",
    "- MSCGraph 具有 DAG（有向无环图）格式，而 Relax 具有表达式格式。\n",
    "- MSCGraph 将张量分类为输入和权重，而 Relax 将张量定义为变量和常量。\n",
    "- MSCGraph 使用节点名称（如 `conv1`，`layer1.conv1` ...）作为搜索节点的主要 ID，而 Relax 使用带有前缀的索引（如 `lvXX`，`gv`）。\n",
    "```\n",
    "\n",
    "导出序列化文件以加载计算图："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"name\": \"main\", \n",
      "  \"inputs\": [\n",
      "    \"inp_0:0\"\n",
      "  ], \n",
      "  \"outputs\": [\n",
      "    \"relu:0\"\n",
      "  ], \n",
      "  \"nodes\": [\n",
      "    {\n",
      "      \"index\": 0, \n",
      "      \"name\": \"inp_0\", \n",
      "      \"shared_ref\": \"\", \n",
      "      \"optype\": \"input\", \n",
      "      \"parents\": [], \n",
      "      \"inputs\": [], \n",
      "      \"outputs\": [\n",
      "        {\n",
      "          \"name\": \"inp_0:0\", \n",
      "          \"alias\": \"inp_0\", \n",
      "          \"dtype\": \"float32\", \n",
      "          \"layout\": \"NCHW\", \n",
      "          \"shape\": [1, 3, 4, 4], \n",
      "          \"prims\": []\n",
      "        }\n",
      "      ], \n",
      "      \"attrs\": {}, \n",
      "      \"weights\": {}\n",
      "    }, \n",
      "    {\n",
      "      \"index\": 1, \n",
      "      \"name\": \"conv2d\", \n",
      "      \"shared_ref\": \"\", \n",
      "      \"optype\": \"nn.conv2d\", \n",
      "      \"parents\": [\n",
      "        \"inp_0\"\n",
      "      ], \n",
      "      \"inputs\": [\n",
      "        \"inp_0:0\"\n",
      "      ], \n",
      "      \"outputs\": [\n",
      "        {\n",
      "          \"name\": \"conv2d:0\", \n",
      "          \"alias\": \"\", \n",
      "          \"dtype\": \"float32\", \n",
      "          \"layout\": \"NCHW\", \n",
      "          \"shape\": [1, 6, 4, 4], \n",
      "          \"prims\": []\n",
      "        }\n",
      "      ], \n",
      "      \"attrs\": {\n",
      "        \"out_layout\": \"NCHW\", \n",
      "        \"data_layout\": \"NCHW\", \n",
      "        \"padding\": \"0,0,0,0\", \n",
      "        \"groups\": \"1\", \n",
      "        \"kernel_layout\": \"OIHW\", \n",
      "        \"strides\": \"1,1\", \n",
      "        \"dilation\": \"1,1\", \n",
      "        \"out_dtype\": \"float32\"\n",
      "      }, \n",
      "      \"weights\": {\"weight\": {\n",
      "          \"name\": \"const\", \n",
      "          \"alias\": \"\", \n",
      "          \"dtype\": \"float32\", \n",
      "          \"layout\": \"OIHW\", \n",
      "          \"shape\": [6, 3, 1, 1], \n",
      "          \"prims\": []}}\n",
      "    }, \n",
      "    {\n",
      "      \"index\": 2, \n",
      "      \"name\": \"relu\", \n",
      "      \"shared_ref\": \"\", \n",
      "      \"optype\": \"nn.relu\", \n",
      "      \"parents\": [\n",
      "        \"conv2d\"\n",
      "      ], \n",
      "      \"inputs\": [\n",
      "        \"conv2d:0\"\n",
      "      ], \n",
      "      \"outputs\": [\n",
      "        {\n",
      "          \"name\": \"relu:0\", \n",
      "          \"alias\": \"relu\", \n",
      "          \"dtype\": \"float32\", \n",
      "          \"layout\": \"NCHW\", \n",
      "          \"shape\": [1, 6, 4, 4], \n",
      "          \"prims\": []\n",
      "        }\n",
      "      ], \n",
      "      \"attrs\": {}, \n",
      "      \"weights\": {}\n",
      "    }\n",
      "  ], \n",
      "  \"prims\": []\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "print(graph.to_json())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "导出用于可视化的 prototxt 文件："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "name: \"main\"\n",
      "layer {\n",
      "  name: \"inp_0\"\n",
      "  type: \"input\"\n",
      "  top: \"inp_0\"\n",
      "  layer_param {\n",
      "    idx: 0\n",
      "    output_0: \"inp_0:0(inp_0)<1,3,4,4|float32|NCHW>\"\n",
      "  }\n",
      "}\n",
      "layer {\n",
      "  name: \"conv2d\"\n",
      "  type: \"nn_conv2d\"\n",
      "  top: \"conv2d\"\n",
      "  bottom: \"inp_0\"\n",
      "  layer_param {\n",
      "    out_layout: \"NCHW\"\n",
      "    out_dtype: \"float32\"\n",
      "    groups: \"1\"\n",
      "    kernel_layout: \"OIHW\"\n",
      "    param_weight: \"const<6,3,1,1|float32|OIHW>\"\n",
      "    strides: \"1,1\"\n",
      "    idx: 1\n",
      "    padding: \"0,0,0,0\"\n",
      "    data_layout: \"NCHW\"\n",
      "    dilation: \"1,1\"\n",
      "    output_0: \"conv2d:0<1,6,4,4|float32|NCHW>\"\n",
      "    input_0: \"inp_0:0(inp_0)<1,3,4,4|float32|NCHW>\"\n",
      "  }\n",
      "}\n",
      "layer {\n",
      "  name: \"relu\"\n",
      "  type: \"nn_relu\"\n",
      "  top: \"relu\"\n",
      "  bottom: \"conv2d\"\n",
      "  layer_param {\n",
      "    idx: 2\n",
      "    input_0: \"conv2d:0<1,6,4,4|float32|NCHW>\"\n",
      "    output_0: \"relu:0(relu)<1,6,4,4|float32|NCHW>\"\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "print(graph.visualize(f\"{temp_dir}/graph.prototxt\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "通过将 `relax/relay` 结构转换成 MSCGraph，可以使用 DAG 风格的方法进行节点的查找与遍历："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": [
     "hide-output"
    ]
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "conv2d 节点 N_1 conv2d <PARENTS: inp_0| CHILDERN: relu>\n",
      "  IN: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>\n",
      "  OUT: conv2d:0<1,6,4,4|float32|NCHW>\n",
      "  OPTYPE: nn.conv2d\n",
      "  SCOPE: block\n",
      "  ATTRS: out_dtype=float32 strides=1,1 kernel_layout=OIHW groups=1 padding=0,0,0,0 data_layout=NCHW dilation=1,1 out_layout=NCHW \n",
      "  WEIGHTS: \n",
      "    weight: const<6,3,1,1|float32|OIHW>\n",
      "\n",
      "input 节点 inp_0:0(inp_0)<1,3,4,4|float32|NCHW>\n",
      "核心 conv 节点 N_1 conv2d <PARENTS: inp_0| CHILDERN: relu>\n",
      "  IN: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>\n",
      "  OUT: conv2d:0<1,6,4,4|float32|NCHW>\n",
      "  OPTYPE: nn.conv2d\n",
      "  SCOPE: block\n",
      "  ATTRS: out_dtype=float32 strides=1,1 kernel_layout=OIHW groups=1 padding=0,0,0,0 data_layout=NCHW dilation=1,1 out_layout=NCHW \n",
      "  WEIGHTS: \n",
      "    weight: const<6,3,1,1|float32|OIHW>\n",
      "\n",
      "父节点 N_0 inp_0 <PARENTS: | CHILDERN: conv2d>\n",
      "  OUT: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>\n",
      "  OPTYPE: input\n",
      "N_0 inp_0 <PARENTS: | CHILDERN: conv2d>\n",
      "  OUT: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>\n",
      "  OPTYPE: input\n",
      "\n",
      "子节点 N_2 relu <PARENTS: conv2d| CHILDERN: >\n",
      "  IN: conv2d:0<1,6,4,4|float32|NCHW>\n",
      "  OUT: relu:0(relu)<1,6,4,4|float32|NCHW>\n",
      "  OPTYPE: nn.relu\n",
      "  SCOPE: block\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Walk every node and report the convolutions.\n",
    "for node in graph.get_nodes():\n",
    "    if node.optype == \"nn.conv2d\":\n",
    "        print(f\"conv2d 节点 {node}\")\n",
    "\n",
    "# Graph-level inputs.\n",
    "for graph_input in graph.get_inputs():\n",
    "    print(f\"input 节点 {graph_input}\")\n",
    "\n",
    "# Look a node up by name, then follow its edges in both directions.\n",
    "node = graph.find_node(\"conv2d\")\n",
    "print(f\"核心 conv 节点 {node}\")\n",
    "for parent in node.parents:\n",
    "    # Print each parent exactly once; the f-string already formats it.\n",
    "    print(f\"父节点 {parent}\")\n",
    "for child in node.children:\n",
    "    print(f\"子节点 {child}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "MSCGraph 的基本结构和 onnx 类似：一个 MSCGraph 中包含多个 MSCJoint（计算节点）和 MSCTensor（数据）。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MSCJoint\n",
    "\n",
    "作用等同于 `relax.Expr`，`torch.nn.Module`，`onnx.Node` 等，即一个计算图中计算逻辑的最小表达单元。一个 `MSCJoint` 对应一个 `relax.Expr` 或 `Function`（如果使用了一些 `pattern` 做分图，例如 `conv2d+bias` fuse 成 `conv2d_bias`），`MSCJoint` 和 `Expr` 的区别在于 `MSCJoint` 包含更多拓扑信息，不仅仅有 `relax.Call.args` 对应的 `inputs`，也包括 `parents`、`children` 以及 `outputs`，可以更方便地获取拓扑关系。\n",
    "\n",
    "以下为 `MSCJoint` 的描述示例："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tvm.contrib.msc.core.ir.graph.MSCJoint,\n",
       " N_1 conv2d <PARENTS: inp_0| CHILDERN: relu>\n",
       "   IN: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>\n",
       "   OUT: conv2d:0<1,6,4,4|float32|NCHW>\n",
       "   OPTYPE: nn.conv2d\n",
       "   SCOPE: block\n",
       "   ATTRS: out_dtype=float32 strides=1,1 kernel_layout=OIHW groups=1 padding=0,0,0,0 data_layout=NCHW dilation=1,1 out_layout=NCHW \n",
       "   WEIGHTS: \n",
       "     weight: const<6,3,1,1|float32|OIHW>)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "node = graph.find_node(\"conv2d\")\n",
    "type(node), node"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "其中包含几个重要成员：\n",
    "\n",
    "- ID(`N_1`)：节点的 `index`，也是遍历计算图的时候使用的排序数值\n",
    "- NAME(`conv2d`)：节点名称，每个节点有唯一的 `name`，用来查找节点\n",
    "- PARENTS/CHILDREN(`inp_0`/`relu`)：节点的拓扑关系（打印输出中 CHILDREN 被拼写为 `CHILDERN`）\n",
    "- `ATTRS`：节点属性，为了向后兼容都是用 string 类型保存，codegen 的时候会 cast 成对应类型\n",
    "- IN/OUT(`inp_0:0`/`conv2d:0`)：节点输入输出，每个节点有 1 到多个 outputs，每个 output 都是一个 MSCTensor；inputs 则为 parents 的 outputs 的引用\n",
    "- WEIGHTS：节点weights，每个 weights 是一个 `ref:MSCTensor` 对，`ref` 表示 weight 类型，如 `\"weight\",\"bias\",\"gamma\"` 等，定义ref的原因主要是考虑到模型压缩针对不同的weight类型操作不同，故需要对weight进行分类"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MSCTensor\n",
    "这种数据结构在relax中并未体现，可以理解为NDArray的抽象。MSCTensor描述每个节点outputs以及weights的信息。通过MSCTensor可以查找到producer和consumers，方便对tensor进行操作时获取上下文。MSCTensor格式设计参照了tensorflow的tensor。以下为一个MSCTensor的描述，包含几个重要属性："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tvm.contrib.msc.core.ir.graph.MSCTensor, [conv2d:0<1,6,4,4|float32|NCHW>])"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "type(node.outputs[0]), node.outputs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "- Name(conv2d:0)：tensor 名字，格式为 `节点名称:序号`，与 TensorFlow 相同。通过这种标记可以查找 tensor 的 producer\n",
    "- Shape(1,6,4,4)：tensor的shape，动态维度用-1表示\n",
    "- Dtype(float32)：tensor的数据类型\n",
    "- Layout(NCHW)：MSC中新加的属性，tensor的数据排布格式。这一属性在剪枝过程中比较重要，对一些计算过程算子的优化也起到参考作用"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "producer N_1 conv2d <PARENTS: inp_0| CHILDERN: relu>\n",
      "  IN: inp_0:0(inp_0)<1,3,4,4|float32|NCHW>\n",
      "  OUT: conv2d:0<1,6,4,4|float32|NCHW>\n",
      "  OPTYPE: nn.conv2d\n",
      "  SCOPE: block\n",
      "  ATTRS: out_dtype=float32 strides=1,1 kernel_layout=OIHW groups=1 padding=0,0,0,0 data_layout=NCHW dilation=1,1 out_layout=NCHW \n",
      "  WEIGHTS: \n",
      "    weight: const<6,3,1,1|float32|OIHW>\n",
      "\n",
      "has consumer N_2 relu <PARENTS: conv2d| CHILDERN: >\n",
      "  IN: conv2d:0<1,6,4,4|float32|NCHW>\n",
      "  OUT: relu:0(relu)<1,6,4,4|float32|NCHW>\n",
      "  OPTYPE: nn.relu\n",
      "  SCOPE: block\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# First output tensor of the node selected above.\n",
    "output = node.output_at(0)\n",
    "tensor_name = output.name\n",
    "\n",
    "# The producer is looked up from the tensor name.\n",
    "producer = graph.find_producer(tensor_name)\n",
    "print(f\"producer {producer}\")\n",
    "\n",
    "# Every node that reads the tensor.\n",
    "for consumer in graph.find_consumers(tensor_name):\n",
    "    print(f\"has consumer {consumer}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
